{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d70339c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.stats import wasserstein_distance\n",
    "import helpers as ph\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import dataframe_image as dfi\n",
    "sns.set_style('whitegrid')\n",
    "styles = ph.VIS_STYLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29748ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "RESULTS_DIR = f'./data/distributions/'\n",
    "CONTEXT = 'default'\n",
    "SAVEFIG = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "566d115a",
   "metadata": {},
   "source": [
    "## Load human and LM opinion distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f60044",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_df = []\n",
    "for wave in ph.PEW_SURVEY_LIST:\n",
    "    SURVEY_NAME = f'American_Trends_Panel_W{wave}'\n",
    "    cdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_combined.csv'))\n",
    "    cdf['survey'] = f'ATP {wave}'\n",
    "    combined_df.append(cdf)\n",
    "combined_df = pd.concat(combined_df)\n",
    "combined_df['Source'] = combined_df.apply(lambda x: 'AI21 Labs' if 'j1-' in x['model_name'].lower() else 'OpenAI',\n",
    "                                          axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cee0355",
   "metadata": {},
   "source": [
    "## Compute average representativeness across dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f1e19f",
   "metadata": {},
   "outputs": [],
   "source": [
    "KEYS = ['model_name', 'attribute', 'group', 'group_order', 'model_order']\n",
    "\n",
    "grouped = combined_df.groupby(KEYS, as_index=False).agg({'WD': np.mean}) \\\n",
    "         .sort_values(by=['model_order', 'group_order'])\n",
    "grouped['Rep'] = 1 - grouped['WD']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00c8bc91",
   "metadata": {},
   "source": [
    "## Load coarse/fine-grained topics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1e6445f",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_info = np.load('./data/human_resp/topic_mapping.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2abb16d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_df['topic_cg'] = combined_df.apply(lambda x: topic_info[x['question']]['cg'], axis=1)\n",
    "combined_df['topic_fg'] = combined_df.apply(lambda x: topic_info[x['question']]['fg'], axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fb7a5c6",
   "metadata": {},
   "source": [
    "## Measure topic-level consistency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca8f6a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "MD_TO_PALETTE = {'POLIDEOLOGY': 'coolwarm_r',\n",
    "                 'EDUCATION': 'PuBuGn',\n",
    "                 'INCOME': 'YlGnBu'}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5946f62f",
   "metadata": {},
   "source": [
    "### Coarse grained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "525e00e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_df = combined_df.explode(['topic_cg']).rename(columns={'topic_cg': 'topic'})\n",
    "\n",
    "KEYS = ['model_name', 'attribute', 'group', 'group_order', 'model_order', 'topic']\n",
    "\n",
    "topic_df = topic_df.groupby(KEYS, as_index=False).agg({'WD': np.mean}) \\\n",
    "         .sort_values(by=['model_order', 'group_order'])\n",
    "topic_df['Rep'] = 1 - topic_df['WD']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27ab2a03",
   "metadata": {},
   "outputs": [],
   "source": [
    "MDS = ['POLIDEOLOGY', 'EDUCATION', 'INCOME']\n",
    "offsets = [1.22, 1.25, 1.225]\n",
    "fig, axarr = plt.subplots(1, len(MDS), figsize=(len(MDS)*4.5, 8), \n",
    "                          sharey=True)\n",
    "\n",
    "\n",
    "for mdi, MD in enumerate(MDS):\n",
    "\n",
    "    tdf = topic_df[topic_df['attribute'] == MD].sort_values(by='Rep')\n",
    "\n",
    "\n",
    "    size_df = tdf.groupby(['model_order', 'model_name', 'topic', 'attribute'], as_index=False)\\\n",
    "                    .agg({'Rep': list})\n",
    "    size_df['ratio'] = size_df.apply(lambda x: max(x['Rep']) / min(x['Rep']), axis=1)\n",
    "    size_df = size_df.drop(columns={'Rep'})\n",
    "\n",
    "    match_df = tdf.groupby(['model_order', 'model_name', 'topic', 'attribute'], \n",
    "                           as_index=False).last()\n",
    "\n",
    "    match_df = pd.merge(match_df, size_df)\\\n",
    "                .sort_values(by=['topic', 'model_order', 'group_order'])\n",
    "\n",
    "    group_list = match_df[['group_order', 'group']].drop_duplicates() \\\n",
    "              .sort_values(by='group_order')['group'].values\n",
    "    palette = MD_TO_PALETTE[MD]\n",
    "\n",
    "    group_to_order = {o: oi for oi, o in enumerate(group_list)}\n",
    "    group_to_color = {o: c for o, c in zip(group_list, \n",
    "                                           sns.color_palette(palette, len(group_list)))}\n",
    "\n",
    "\n",
    "    ax = axarr[mdi]\n",
    "    sns.scatterplot(data=match_df, y='topic', x='model_name', hue='group', \n",
    "                    size='ratio',\n",
    "                    palette=group_to_color, ax=ax, sizes=(20, 250), edgecolor='black')\n",
    "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n",
    "    if mdi == 0:\n",
    "        ax.set_yticklabels(ax.get_yticklabels(), fontsize=12)\n",
    "        \n",
    "    ax.legend(title=MD, ncol=1, loc=\"upper center\", \n",
    "              bbox_to_anchor=(0.5, offsets[mdi]+0.5))\n",
    "    \n",
    "    h, l = ax.get_legend_handles_labels()\n",
    "    h, l = np.array(h), np.array(l)    \n",
    "    lvalid = [li for li, ll in enumerate(l) if ll in group_list]\n",
    "    h, l = h[lvalid], l[lvalid]\n",
    "    lidx = np.argsort([group_to_order[ll] for ll in l])\n",
    "    locs = [[3.48, 0.95],\n",
    "            [2.52, 0.65],\n",
    "            [1.3, 0.3]]\n",
    "    ax.legend(h[lidx], l[lidx], title=MD, ncol=1, loc=\"upper center\", bbox_to_anchor=locs[mdi])\n",
    "    \n",
    "    \n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    \n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "if SAVEFIG: plt.savefig('./figures/consistency_cg.png', bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85f65fbe",
   "metadata": {},
   "source": [
    "### Fine-grained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d83d957",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_fg_df = combined_df.explode(['topic_fg']).rename(columns={'topic_fg': 'topic'})\n",
    "\n",
    "KEYS = ['model_name', 'attribute', 'group', 'group_order', 'model_order', 'topic']\n",
    "\n",
    "topic_fg_df = topic_fg_df.groupby(KEYS, as_index=False).agg({'WD': np.mean}) \\\n",
    "         .sort_values(by=['model_order', 'group_order'])\n",
    "topic_fg_df['Rep'] = 1 - topic_fg_df['WD']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d34035d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "MDS = ['POLIDEOLOGY', 'EDUCATION', 'INCOME']\n",
    "offsets = [1.22, 1.25, 1.225]\n",
    "fig, axarr = plt.subplots(1, len(MDS), figsize=(len(MDS)*4.5, 16), \n",
    "                          sharey=True)\n",
    "\n",
    "\n",
    "for mdi, MD in enumerate(MDS):\n",
    "\n",
    "    tdf = topic_fg_df[topic_fg_df['attribute'] == MD].sort_values(by='Rep')\n",
    "\n",
    "\n",
    "    size_df = tdf.groupby(['model_order', 'model_name', 'topic', 'attribute'], as_index=False)\\\n",
    "                    .agg({'Rep': list})\n",
    "    size_df['ratio'] = size_df.apply(lambda x: max(x['Rep']) / min(x['Rep']), axis=1)\n",
    "    size_df = size_df.drop(columns={'Rep'})\n",
    "\n",
    "    match_df = tdf.groupby(['model_order', 'model_name', 'topic', 'attribute'], \n",
    "                           as_index=False).last()\n",
    "\n",
    "    match_df = pd.merge(match_df, size_df)\\\n",
    "                .sort_values(by=['topic', 'model_order', 'group_order'])\n",
    "\n",
    "    group_list = match_df[['group_order', 'group']].drop_duplicates() \\\n",
    "              .sort_values(by='group_order')['group'].values\n",
    "    palette = MD_TO_PALETTE[MD]\n",
    "\n",
    "    group_to_order = {o: oi for oi, o in enumerate(group_list)}\n",
    "    group_to_color = {o: c for o, c in zip(group_list, \n",
    "                                           sns.color_palette(palette, len(group_list)))}\n",
    "\n",
    "\n",
    "    ax = axarr[mdi]\n",
    "    sns.scatterplot(data=match_df, y='topic', x='model_name', hue='group', \n",
    "                    size='ratio',\n",
    "                    palette=group_to_color, ax=ax, sizes=(20, 250), edgecolor='black')\n",
    "    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)\n",
    "    if mdi == 0:\n",
    "        ax.set_yticklabels(ax.get_yticklabels(), fontsize=12)\n",
    "        \n",
    "    ax.legend(title=MD, ncol=1, loc=\"upper center\", \n",
    "              bbox_to_anchor=(0.5, offsets[mdi]+0.5))\n",
    "    \n",
    "    h, l = ax.get_legend_handles_labels()\n",
    "    h, l = np.array(h), np.array(l)    \n",
    "    lvalid = [li for li, ll in enumerate(l) if ll in group_list]\n",
    "    h, l = h[lvalid], l[lvalid]\n",
    "    lidx = np.argsort([group_to_order[ll] for ll in l])\n",
    "    locs = [[3.48, 0.95],\n",
    "            [2.52, 0.65],\n",
    "            [1.3, 0.3]]\n",
    "    ax.legend(h[lidx], l[lidx], title=MD, ncol=1, loc=\"upper center\", bbox_to_anchor=locs[mdi])\n",
    "    \n",
    "    \n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    \n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "if SAVEFIG: plt.savefig('./figures/consistency_fg.png', bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e67f85e",
   "metadata": {},
   "source": [
    "## Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faad30df",
   "metadata": {},
   "outputs": [],
   "source": [
    "KEYS = ['model_name', 'attribute', 'model_order']\n",
    "\n",
    "best_overall = grouped.sort_values(by='Rep')\\\n",
    "              .groupby(KEYS, as_index=False) \\\n",
    "              .last().rename(columns={'group': 'best_overall'})\n",
    "best_topic = topic_df.sort_values(by='Rep')\\\n",
    "              .groupby(KEYS + ['topic'], as_index=False) \\\n",
    "              .last().rename(columns={'group': 'best_topic'})\n",
    "consistency_df = pd.merge(best_overall[KEYS + ['best_overall']],\n",
    "                          best_topic[KEYS + ['topic', 'best_topic']])\n",
    "consistency_df['C'] = consistency_df.apply(lambda x: (x['best_topic'] == x['best_overall']),\n",
    "                                                 axis=1)\n",
    "consistency_df = consistency_df.groupby(KEYS, as_index=False).agg({'C': np.mean}) \\\n",
    "              .groupby(['model_name', 'model_order'], as_index=False).agg({'C': np.mean})\n",
    "consistency_df['Source'] = consistency_df.apply(lambda x: 'AI21 Labs' if 'j1-' in x['model_name'].lower() else 'OpenAI',\n",
    "                                          axis=1)\n",
    "consistency_df = consistency_df.rename(columns={'model_name': ''}) \\\n",
    "                .sort_values(by='model_order')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf5d882a",
   "metadata": {},
   "outputs": [],
   "source": [
    "table = pd.pivot_table(consistency_df, \n",
    "                       columns=['Source', ''], \n",
    "                       values='C', \n",
    "                       sort=False)\n",
    "table_vis = table.style.background_gradient('Reds_r', axis=1).set_table_styles(styles)  \\\n",
    "                        .set_properties(**{\"font-size\":\"0.8rem\"}).format(precision=3)\n",
    "if SAVEFIG: table_vis.hide_index().export_png('./figures/consistency.png')\n",
    "display(table_vis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e345ec4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
